package org.snlab.runtime;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.google.common.io.ByteStreams;
import com.google.common.io.Files;
import com.google.protobuf.ByteString;

import org.snlab.proto.RuntimeGrpc;
import org.snlab.proto.Ddmpt.Msg;
import org.snlab.proto.Ddmpt.Msgs;
import org.snlab.proto.Ddmpt.Property;
import org.snlab.proto.Ddmpt.To;
import org.snlab.proto.Ddmpt.Vnode;
import org.snlab.proto.Ddmpt.Vnodes;
import org.snlab.proto.RuntimeGrpc.RuntimeBlockingStub;
import org.snlab.util.Replay;

import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;

/**
 * Virtual router of a single daemon.
 *
 * <p>It holds the daemon's local vnodes and buffers the latest message received from each
 * predecessor of every vnode. It forwards computed results to downstream vnodes through one of
 * two paths:
 * <ul>
 *   <li>{@code Replay.fakeStub}, when it is set (trace replay);</li>
 *   <li>otherwise, gRPC blocking stubs cached per downstream address.</li>
 * </ul>
 *
 * <p>NOTE(review): {@code vidToMsgs} and {@code vidToMessages} are plain {@link HashMap}s, while
 * {@code vidToVnode} and {@code addrToStub} are concurrent maps. Confirm whether {@link #route}
 * may be invoked from several threads at once.
 */
public class Vrouter {
    private Daemon daemon;
    private final ConcurrentHashMap<Long, Vnode> vidToVnode = new ConcurrentHashMap<>();
    private final ConcurrentHashMap<String, RuntimeBlockingStub> addrToStub = new ConcurrentHashMap<>();

    // vid -> (from vid -> latest Msg received from that predecessor)
    private final Map<Long, Map<Long, Msg>> vidToMsgs = new HashMap<>();
    // vid -> (from vid -> latest Message, i.e. the Msg plus its decoded header space)
    private final Map<Long, Map<Long, Message>> vidToMessages = new HashMap<>();

    // wall-clock baseline in ns, set by resetTs()
    private long ts = 0;
    // CPU-time baseline in ns of the calling thread, set by resetCpuTs()
    private long cputs = 0;

    Runtime runtime = Runtime.getRuntime();
    ThreadMXBean mxBean = ManagementFactory.getThreadMXBean();

    /**
     * Loads this daemon's vnodes from {@code <envDir>/gen/trace/<name>.vnodes} and registers them.
     * Messages are not loaded here; see {@link #replayTrace()}. I/O and parse errors are printed
     * and otherwise ignored.
     */
    public void loadTrace() {
        try {
            Vnodes vnodes = Vnodes.parseFrom(readTraceFile(".vnodes"));
            add(vnodes);
        } catch (IOException e) {
            // best effort: report and keep running without the trace
            e.printStackTrace();
        }
    }

    /**
     * Replays the messages stored in {@code <envDir>/gen/trace/<name>.msgs}.
     *
     * <p>Each message's header space is deserialized. The message is then routed, but only when
     * {@code Replay.fakeStub} is set. I/O and parse errors are printed and otherwise ignored.
     */
    public void replayTrace() {
        try {
            Msgs msgs = Msgs.parseFrom(readTraceFile(".msgs"));
            for (Msg msg : msgs.getMsgList()) {
                int hs = daemon.getBddEngine().deserialize(msg.getHs().toByteArray());
                Message message = new Message(msg, hs);
                if (Replay.fakeStub != null) {
                    route(message, hs);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads both trace files (vnodes and messages) and processes the messages.
     *
     * <p>Each message is routed when {@code Replay.fakeStub} is set. Otherwise it is only buffered
     * for its destination vnode. The vnodes are registered only if both files were read
     * successfully. I/O and parse errors are printed and otherwise ignored.
     */
    public void load() {
        try {
            Vnodes vnodes = Vnodes.parseFrom(readTraceFile(".vnodes"));
            Msgs msgs = Msgs.parseFrom(readTraceFile(".msgs"));
            add(vnodes);
            for (Msg msg : msgs.getMsgList()) {
                int hs = daemon.getBddEngine().deserialize(msg.getHs().toByteArray());
                Message message = new Message(msg, hs);
                if (Replay.fakeStub != null) {
                    route(message, hs);
                } else {
                    // Buffer without forwarding so reroute() can re-check it later.
                    // NOTE(review): vidToMsgs is not updated on this path, so dump() will not
                    // include these messages; confirm that is intended.
                    vidToMessages.get(msg.getTo()).put(msg.getFrom(), message);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Returns {@code <envDir>/gen/trace/<name><suffix>} for this daemon. */
    private String tracePath(String suffix) {
        return daemon.getConfig().getEnvDir() + "/gen/trace/" + daemon.getConfig().getName() + suffix;
    }

    /**
     * Reads the whole trace file with the given suffix.
     *
     * <p>The stream is closed afterwards; ByteStreams.toByteArray does not close its argument.
     */
    private byte[] readTraceFile(String suffix) throws IOException {
        try (FileInputStream in = new FileInputStream(tracePath(suffix))) {
            return ByteStreams.toByteArray(in);
        }
    }

    /** Records the current wall-clock time as the baseline for {@link #getDuration()}. */
    public void resetTs() {
        ts = System.nanoTime();
    }

    /** Records the calling thread's CPU time as the baseline for {@link #getCpuTime()}. */
    public void resetCpuTs() {
        cputs = mxBean.getCurrentThreadCpuTime();
    }

    /** @return wall-clock nanoseconds elapsed since the last {@link #resetTs()} */
    public long getDuration() {
        return System.nanoTime() - ts;
    }

    /** @return CPU nanoseconds the calling thread used since the last {@link #resetCpuTs()} */
    public long getCpuTime() {
        return mxBean.getCurrentThreadCpuTime() - cputs;
    }

    /**
     * Requests a garbage collection, then returns the used heap in bytes.
     *
     * <p>{@code System.gc()} is only a hint to the JVM. The returned value is total memory minus
     * free memory.
     */
    public long getMemory() {
        System.gc();
        return runtime.totalMemory() - runtime.freeMemory();
    }

    /**
     * Appends one line to {@code <envDir>/gen/stat.txt}. The fields, in order, are:
     * <ol>
     *   <li>the CPU baseline;</li>
     *   <li>the CPU time used;</li>
     *   <li>the wall-clock time;</li>
     *   <li>the CPU usage normalised by the processor count;</li>
     *   <li>the used heap.</li>
     * </ol>
     */
    public void writeProfile() {
        long cpuTime = getCpuTime();
        long totalTime = getDuration();
        double cpuUsage = 1.0 * cpuTime / (totalTime * runtime.availableProcessors());
        long mem = getMemory();
        String str = cputs + " " + cpuTime + " " + totalTime + " " + cpuUsage + " " + mem + "\n";
        Daemon.logger.log(daemon.getConfig().getEnvDir() + "/gen/stat.txt", str);
    }

    public void setDaemon(Daemon daemon) {
        this.daemon = daemon;
    }

    /**
     * Registers a vnode and gives it empty message buffers.
     *
     * <p>A vnode that is already registered is replaced, and its buffers are reset.
     */
    public void add(Vnode vnode) {
        vidToVnode.put(vnode.getVid(), vnode);
        vidToMsgs.put(vnode.getVid(), new HashMap<>());
        vidToMessages.put(vnode.getVid(), new HashMap<>());
    }

    /** Registers every vnode in {@code vnodes}; see {@link #add(Vnode)}. */
    public void add(Vnodes vnodes) {
        for (Vnode vnode : vnodes.getVnodeList()) {
            add(vnode);
        }
    }

    /** @return a live view of the registered vnodes */
    public Collection<Vnode> getVnodes() {
        return vidToVnode.values();
    }

    /**
     * Buffers a received message for its destination vnode.
     *
     * <p>Once messages from enough predecessors are buffered, this computes and forwards the
     * vnode's outputs via {@link #map}.
     *
     * @param message received message
     * @param hs      decoded header space of {@code message}
     */
    public void route(Message message, int hs) {
        Msg msg = message.getMsg();
        long dst = msg.getTo();
        vidToMessages.get(dst).put(msg.getFrom(), message);
        Map<Long, Msg> received = vidToMsgs.get(dst);
        received.put(msg.getFrom(), msg);

        Vnode vnode = vidToVnode.get(dst);
        // The vnode fires once at least getFromCount() distinct predecessors have reported, and
        // again on every later message. ">=" rather than "==" because, per the original author,
        // the src node has no prev vnodes.
        if (received.size() >= vnode.getFromCount()) {
            if (vnode.getToList().isEmpty()) {
                // no downstream vnodes: log how many messages arrived and their summed counts
                int count = received.values().stream().mapToInt(m -> m.getProperty().getCount()).sum();
                Daemon.logger.log("recv msgs at " + daemon.getConfig().getName() + ": "
                        + received.size() + " " + count);
            }
            map(dst, hs);
        }
    }

    /**
     * Re-checks the buffered messages whose header space overlaps {@code change}.
     *
     * <p>For each such message, the header space must be a subset of every next hop's port
     * predicate (local ECMP check). On the first violation, this method:
     * <ol>
     *   <li>logs the elapsed time;</li>
     *   <li>forwards the message to {@code Replay.fakeStub}, if set;</li>
     *   <li>returns.</li>
     * </ol>
     * If there is no violation, {@link #remap()} is called.
     *
     * @param change BDD handle of the changed header space
     */
    public void reroute(int change) {
        for (Map.Entry<Long, Map<Long, Message>> entry : vidToMessages.entrySet()) {
            for (Message message : entry.getValue().values()) {
                if (daemon.getBddEngine().and(change, message.getHs()) == 0) {
                    continue;
                }
                for (To to : vidToVnode.get(entry.getKey()).getToList()) {
                    int portHs = daemon.getDIB().getPortToPred().getOrDefault(to.getEport(), 0);
                    if (!daemon.getBddEngine().subset(message.getHs(), portHs)) {
                        Daemon.logger.log("violation  found time: " + getDuration());
                        if (Replay.fakeStub != null) {
                            Replay.fakeStub.process(message.getMsg());
                        }
                        return;
                    }
                }
            }
        }
        remap();
    }

    /** Placeholder: remapping after a change is not implemented yet. */
    private void remap() {

    }

    /**
     * Groups the BDDs of the messages buffered for {@code vid} by property, OR-ing together the
     * BDDs that share a property. Currently unused.
     */
    private Map<Property, Integer> merge(long vid) {
        Map<Property, Integer> propToHs = new HashMap<>();
        for (Msg msg : vidToMsgs.get(vid).values()) {
            propToHs.merge(msg.getProperty(), msg.getBdd(),
                    (existing, incoming) -> this.daemon.getDIB().getBDDEngine().or(incoming, existing));
        }
        return propToHs;
    }

    /**
     * For each downstream {@code To} of vnode {@code vid}:
     * <ol>
     *   <li>builds one tuple per buffered message, using the intersection of {@code hs} with the
     *       egress port's predicate;</li>
     *   <li>reduces those tuples to a single one;</li>
     *   <li>sends the result to the downstream vnode.</li>
     * </ol>
     *
     * @return the tuples built for each egress port
     */
    private Map<String, List<PropTuple<Property>>> map(long vid, int hs) {
        Map<String, List<PropTuple<Property>>> eportToPropTuples = new HashMap<>();
        BDDEngine bddEngine = daemon.getDIB().getBDDEngine();
        Collection<Msg> received = vidToMsgs.get(vid).values();

        for (To to : vidToVnode.get(vid).getToList()) {
            String port = to.getEport();
            int portHs = daemon.getDIB().getPortToPred().getOrDefault(port, 0);
            List<PropTuple<Property>> portTuples = eportToPropTuples.computeIfAbsent(port, k -> new ArrayList<>());
            List<PropTuple<Property>> propTuples = new ArrayList<>();
            for (Msg msg : received) {
                // NOTE(review): the intersection uses the triggering `hs`, not msg's own header
                // space, so every tuple for this port gets the same hs. Confirm this is intended.
                int intersection = bddEngine.and(hs, portHs);
                PropTuple<Property> propTuple = new PropTuple<>(intersection, msg.getProperty());
                portTuples.add(propTuple);
                propTuples.add(propTuple);
            }
            PropTuple<Property> pTuple = reduce(propTuples);

            ByteString bs = ByteString.copyFrom(bddEngine.serialize(pTuple.getHs()));
            Msg out = Msg.newBuilder().setFrom(vid).setTo(to.getVid()).setHs(bs).setProperty(pTuple.getProp()).build();
            send(to.getAddr(), out);
        }

        return eportToPropTuples;
    }

    /**
     * Delivers {@code msg} through {@code Replay.fakeStub} when it is set. Otherwise it uses the
     * cached plaintext gRPC stub for {@code addr}.
     */
    private void send(String addr, Msg msg) {
        if (Replay.fakeStub != null) {
            Replay.fakeStub.process(msg);
            return;
        }
        // computeIfAbsent builds at most one channel per address, even when called concurrently.
        // NOTE(review): the channels are never shut down.
        RuntimeBlockingStub stub = addrToStub.computeIfAbsent(addr, target -> {
            ManagedChannel channel = ManagedChannelBuilder.forTarget(target).usePlaintext().build();
            return RuntimeGrpc.newBlockingStub(channel);
        });
        stub.process(msg);
    }

    /**
     * Combines the tuples into one: the header space is the OR of all the tuples' header spaces,
     * and the property count is the sum of their counts (ECMP path count).
     *
     * <p>An empty list yields header space 0 and count 0.
     */
    private PropTuple<Property> reduce(List<PropTuple<Property>> propTuples) {
        BDDEngine bddEngine = daemon.getDIB().getBDDEngine();
        int hs = propTuples.stream().map(PropTuple::getHs).reduce(0, (x, y) -> bddEngine.or(x, y));
        int count = propTuples.stream().map(PropTuple::getProp).mapToInt(Property::getCount).sum();

        PropTuple<Property> pTuple = new PropTuple<>();
        pTuple.setHs(hs);
        pTuple.setProp(Property.newBuilder().setCount(count).build());
        return pTuple;
    }

    /**
     * Writes the registered vnodes and all messages buffered in {@code vidToMsgs} to
     * {@code <envDir>/gen/trace/<name>.vnodes} and {@code .msgs}. I/O errors are printed and
     * otherwise ignored.
     */
    public void dump() {
        Vnodes vnodes = Vnodes.newBuilder().addAllVnode(daemon.getVrouter().getVnodes()).build();

        List<Msg> msgs = new ArrayList<>();
        for (Map<Long, Msg> fromToMsg : vidToMsgs.values()) {
            msgs.addAll(fromToMsg.values());
        }
        Msgs ms = Msgs.newBuilder().addAllMsg(msgs).build();

        try {
            Files.write(vnodes.toByteArray(), new File(tracePath(".vnodes")));
            Files.write(ms.toByteArray(), new File(tracePath(".msgs")));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
